# Chapter 8 Lab: Decision Trees

# Fitting Classification Trees
library(tree)
library(ISLR)
attach(Carseats)
?Carseats
fix(Carseats)

# Dichotomize Sales; tree() requires a factor response
# (R >= 4.0 no longer converts strings to factors automatically)
High = factor(ifelse(Sales <= 8, "No", "Yes"))
Carseats = data.frame(Carseats, High)

# The syntax of tree() is similar to that of lm()
tree.carseats = tree(High ~ . - Sales, Carseats)

# summary() reports the variables used in tree construction, the number of
# terminal nodes, and the training error rate.
# A small deviance indicates a good fit to the training data.
summary(tree.carseats)

# Plot the tree structure
plot(tree.carseats)
# Display the node labels
text(tree.carseats, pretty = 0)
# We see that the most important predictor of Sales is shelving location

# Printing the tree object displays the split criterion (e.g. Price < 92.5),
# the number of observations in that branch, the deviance, the overall
# prediction for the branch (Yes or No), and the fraction of observations in
# that branch that take on values of Yes and No
tree.carseats

# We should use the test error, not the training error, to evaluate the
# performance of a tree
set.seed(2)
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train, ]
High.test = High[-train]
tree.carseats = tree(High ~ . - Sales, Carseats, subset = train)
tree.pred = predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
(86 + 57) / 200  # fraction of test observations correctly classified (71.5%)

# Pruning the tree via cross-validation
set.seed(3)
# FUN = prune.misclass: use the classification error rate to guide the
# pruning process (the default for cv.tree() is the deviance)
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)
# size: number of terminal nodes (leaves)
# k: the tuning parameter alpha in the slides
# dev: the cross-validation error rate (here, the number of CV errors)
cv.carseats
# The tree with 9 terminal nodes results in the lowest cross-validation
# error rate, with 50 cross-validation errors

# We plot the error rate as a function of both size and k
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")

# We now apply prune.misclass() to prune the tree and obtain the
# nine-node tree
prune.carseats = prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

# Compute the test error rate of the pruned tree
tree.pred = predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
(94 + 60) / 200  # 77% correct: pruning improved both interpretability and accuracy

# Fitting Regression Trees to the Boston data
library(MASS)
?Boston

# Use half the observations as training data
set.seed(1)
train = sample(1:nrow(Boston), nrow(Boston) / 2)

# Grow a large tree using the training data
# medv: median home value in $1000s
tree.boston = tree(medv ~ ., Boston, subset = train)
summary(tree.boston)  # only three of the variables were used

# Plot the tree
plot(tree.boston)
text(tree.boston, pretty = 0)

# Pruning the tree
cv.boston = cv.tree(tree.boston)
plot(cv.boston$size, cv.boston$dev, type = "b")
# The most complex tree is selected by cross-validation,
# so we use the unpruned tree to make predictions on the test set
yhat = predict(tree.boston, newdata = Boston[-train, ])
boston.test = Boston[-train, "medv"]
plot(yhat, boston.test)
abline(0, 1)
sqrt(mean((yhat - boston.test)^2))
# The test RMSE is about 5.005, so this model leads to test predictions that
# are within around $5,005 of the true median home value for the suburb
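
# Aside (a minimal sketch, not part of the lab above): even though
# cross-validation favors the unpruned tree here, we could still prune the
# Boston tree with prune.tree() and check the effect on the test RMSE.
# best = 5 is an illustrative choice, not one selected by cross-validation.
prune.boston = prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
yhat.prune = predict(prune.boston, newdata = Boston[-train, ])
sqrt(mean((yhat.prune - boston.test)^2))  # compare against the unpruned RMSE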
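
# Aside (a sketch, assuming prune.carseats, Carseats.test, and High.test are
# still in the workspace): rather than summing cells of the confusion matrix
# by hand, the test accuracy can be computed by comparing the predicted and
# true labels directly.
prune.pred = predict(prune.carseats, Carseats.test, type = "class")
mean(prune.pred == High.test)  # fraction of test observations correctly classified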